%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% GE_IRF_Baseline.m
% Simulate impulse responses of a GE heterogeneous-firm investment model (macro shocks, real adjustment costs) to an uncertainty shock
%
% Ivan Alfaro, Nick Bloom and Xiaoji Lin 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Start from a clean command window and workspace. Every object this script
% reads but never defines (e.g. V, K0, k0, thetaKold, thetapold, Ksim,
% distzsk) therefore has to come from this .mat file.
clc;
clearvars;
load('ks_het_firms_real_only.mat');

% ----------------------------------------------------------------------
% IRF controls. Each of the numsimIRF economies is simulated for
% Ttot = shockperIRF + lengthIRF periods; the shock hits at t = shockperIRF.
% ----------------------------------------------------------------------
numper      = 2000;   % period of the unconditional simulation used as the starting point (K and distribution)
numsimIRF   = 500;    % number of shocked economies to simulate
lengthIRF   = 30;     % number of periods simulated after the shock
shockperIRF = 200;    % period in which the economy is shocked
checkbounds = 1;      % NOTE(review): not referenced elsewhere in this script - confirm intended use

numperIRF = numsimIRF*lengthIRF;   % total number of post-shock periods across all IRF economies
Ttot      = shockperIRF + lengthIRF;

% Market-clearing price search controls
maxpit  = 50;    % maximum iterations of the price search
pwindow = 0.1;   % the restart bracket is the price forecast +/- 4*pwindow
pcutoff = 15;    % iteration at which the price bracket is restarted

% Seed the random number generator for reproducibility
% (alternatives that were tried: rng('default') or rng(1))
randseed = 2501;
rng(randseed);

% ----------------------------------------------------------------------
% Aggregate IRF series: one column per shocked economy, all start at zero.
% ----------------------------------------------------------------------
YsimIRF = zeros(Ttot,numsimIRF);
[KsimIRF, NsimIRF, IsimIRF, HsimIRF, ACsimIRF, CsimIRF, psimIRF, perrsimIRF, ...
    DivsimIRF, K0simIRF, KZsimIRF] = deal(YsimIRF);

% Firm-level objects on the (k',z,s,k,p) grid
z1       = permute(repmat(z0,[1 snum knum  knum  pnum]),[4 1 2 3 5]);   %(k',z,s,k,p)
ivaltemp = kprime1 - (1-delta)*k1;                                      %(k',z,s,k,p)
ivaltemp(abs(ivaltemp)<=ZeroBoundary) = 0;   % treat near-zero investment as exactly zero
ival     = ivaltemp;

% Every economy starts from aggregate capital at period numper of the
% unconditional simulation
KsimIRF(1,:) = Ksim(numper);
distzskIRF   = zeros(znum,snum,knum,Ttot);
% NOTE(review): TOPkmax, BOTkmax and distzskIRFsimct are never written later
% in this script, so they are saved as all zeros; confirm whether they are
% still needed.
[TOPkmax, BOTkmax] = deal(zeros(numsimIRF,1));
distzskIRFsimct    = zeros(znum,snum,knum,Ttot,numsimIRF);

% Aggregate state indices and the uniform draws that drive their Markov
% transitions. The two rand calls must stay in this order so the seeded
% draws are reproduced.
[AsimposIRF, sigmaAsimposIRF] = deal(zeros(Ttot,numsimIRF));
AsimposIRF(1,:)      = Ainit;
sigmaAsimposIRF(1,:) = Sinit;
AshocksIRF           = rand(Ttot,numsimIRF);
sigmaAshocksIRF      = rand(Ttot,numsimIRF);

% kz = z.*k on the (k',z,s,k,p) grid does not change with t or simct, so it
% is formed once here rather than every period.
kz = z1.*k1;                                                          %(k',z,s,k,p)

for simct = 1:numsimIRF
    % ---------------------------------------------------------------------
    % 1. Initial firm distribution: period numper of the unconditional
    %    simulation, normalised to unit mass.
    % ---------------------------------------------------------------------
    distzskIRF           = zeros(znum,snum,knum,Ttot);
    distzskIRF(:,:,:,1)  = distzsk(:,:,:,numper);
    distzskIRF(:,:,:,1)  = distzskIRF(:,:,:,1)/ sum(sum(sum(distzskIRF(:,:,:,1))));

    % ---------------------------------------------------------------------
    % 2. Aggregate state paths. The uncertainty (S) and productivity (A)
    %    indices evolve by comparing next period's uniform draw with the
    %    cumulative transition rows. At t == shockperIRF the uncertainty
    %    index is first forced to 2 (the shock) and then transitions from there.
    % ---------------------------------------------------------------------
    for t = 1:Ttot-1
        if t == shockperIRF
            sigmaAsimposIRF(t,simct) = 2;
        end
        % time-varying uncertainty
        Sct      = sigmaAsimposIRF(t,simct);
        Sprimect = find(PisumSigma(Sct,:) >= sigmaAshocksIRF(t+1,simct), 1);
        sigmaAsimposIRF(t+1,simct) = Sprimect;
        % aggregate productivity (its transition depends on today's S)
        Act      = AsimposIRF(t,simct);
        Aprimect = find(PisumA(Act,:,Sct) >= AshocksIRF(t+1,simct), 1);
        AsimposIRF(t+1,simct) = Aprimect;
    end

    % disp('Simulating the model for impulse response functions.')

    % ---------------------------------------------------------------------
    % 3. Period-by-period GE simulation.
    % ---------------------------------------------------------------------
    Act = 0; Sct = 0;
    for t = 1:(Ttot-1)
        % aggregate productivity and uncertainty this period
        Act       = AsimposIRF(t,simct);
        Aval      = A0(Act);
        Sct       = sigmaAsimposIRF(t,simct);
        sigmaAval = sigmaA(Sct);
        % aggregate capital and the forecast rule log K' = thetaK1 + thetaK2*log K
        Kval      = KsimIRF(t,simct);
        Kprimeval = exp(thetaKold(Act,Sct,1)+ thetaKold(Act,Sct,2)*log(Kval));
        Kprimeind = sum(Kprimeval>=K0);

        % locate K' on the aggregate grid K0, clamping to the end intervals
        if (Kprimeind==Knum)
            Kprimeind = Knum-1;
            Kprimewgt = 1.0;
        elseif (Kprimeind==0)
            Kprimeind = 1;
            Kprimewgt = 0.0;
        else
            Kprimewgt  = (Kprimeval-K0(Kprimeind))/(K0(Kprimeind+1)-K0(Kprimeind));
        end
        % price forecast log p = thetap1 + thetap2*log K; it centres the
        % restart bracket of the market-clearing price search below
        pfcstval  = exp(thetapold(Act,Sct,1) + thetapold(Act,Sct,2)*log(Kval));

        % per-period payoff on the (k',z,s,k,p) grid
        yval     = Aval*k1_z_alpha;                                   %(k',z,s,k,p)
        ACval    = c_k*(ival~=0).*yval;   % cost c_k*y, paid whenever investment is non-zero
        divval   = yval - ival - ACval;                               %(k',z,s,k,p)
        Rmatp    = p0mat.*divval;                                     %(k',z,s,k,p)
        zval     = z1;

        if t==shockperIRF
            % ---- shock period: every firm is moved to firm-level state sct_h ----
            sct_l = 1; sct_h = 2;
            % Pool the sct_l and sct_h mass of each (z,k) cell into sct_h. Every
            % cell is pooled, including those whose sct_l mass is zero, so no
            % sct_h mass is lost.
            distzskIRF_new = zeros(znum,snum,knum);
            distzskIRF_new(:,sct_h,:) = distzskIRF(:,sct_l,:,t) + distzskIRF(:,sct_h,:,t);

            distzskIRFtemp = repmat(distzskIRF_new,[1 1 1 pnum]);          %(z,s,k,p)

            % continuation value: interpolate V in K', take expectations over
            % (A',S'), then over (z',s') using sct_h's transitions for every firm
            EVMAT3 = zeros(knum,znum);
            temp  = (1-Kprimewgt)*squeeze(V(:,:,Kprimeind,:,:,:))...
                + Kprimewgt*squeeze(V(:,:,Kprimeind+1,:,:,:));            %(A',S',z',s',k')
            temp1 = PiSigma(Sct,:)*reshape(PiA(Act,:,Sct)*reshape(temp,[Anum,Snum*znum*snum*knum]),...
                [Snum,znum*snum*knum]);
            for zct = 1:znum
                EVMAT3(:,zct) = Pisigma(sct_h,:)*reshape(Piz(zct,:,sct_h)*reshape(temp1,[znum,snum*knum]),[snum,knum]);
                %(z',s',k') to (z',s'*k') to (1,s',k') to (k',1)
            end
            EVMAT = repmat(EVMAT3,[1 1 snum]);

            % payoffs and policies evaluated at sct_h, replicated over s
            yval   = repmat(yval(:,:,sct_h,:,:),[1 1 snum 1 1]);          %(k',z,s,k,p)
            zval   = repmat(z1(:,:,sct_h,:,:),[1 1 snum 1 1]);            %(k',z,s,k,p)
            kzval  = repmat(kz(:,:,sct_h,:,:),[1 1 snum 1 1]);            %(k',z,s,k,p)
            Rmatp2 = repmat(Rmatp(:,:,sct_h,:,:),[1 1 snum 1 1]);         %(k',z,s,k,p)

            EVMAT_H = repmat(EVMAT,[1 1 1 knum pnum]);                    %(k',z,s,k,p)
            RHSMAT2 = Rmatp2 + beta*EVMAT_H;

            [~,kprimeinddum] = max(RHSMAT2,[],1);
            kprimeindp = squeeze(kprimeinddum);
            kprimep    = k0(kprimeindp);                                  %(z,s,k,p)

            yvalp      = squeeze(yval(1,:,:,:,:));                        %(z,s,k,p)
            zvalp      = squeeze(zval(1,:,:,:,:));                        %(z,s,k,p)
            kzvalp     = squeeze(kzval(1,:,:,:,:));                       %(z,s,k,p)
            ivalptemp  = (kprimep - (1-delta)*squeeze(k1(1,:,:,:,:)));
            ivalptemp(abs(ivalptemp)<=ZeroBoundary)=0;
            ivalp      = repmat(ivalptemp(:,sct_h,:,:),[1 snum 1 1]);     %(z,s,k,p)
            acvalp     = c_k*(ivalp~=0).*yvalp;
            divvalp    = (yvalp-ivalp-acvalp);

            kprimep        = repmat(kprimep(:,sct_h,:,:),[1 snum 1 1]);
            polmatp_interp = repmat(kprimeindp(:,sct_h,:,:),[1 snum 1 1]);   %(z,s,k,p)

            % aggregates, one value per price grid point
            Yp0        = squeeze(sum(sum(sum(yvalp.*distzskIRFtemp))));
            KZp0       = squeeze(sum(sum(sum(kzvalp.*distzskIRFtemp))));
            Ip0        = squeeze(sum(sum(sum(ivalp.*distzskIRFtemp))));
            ACp0       = squeeze(sum(sum(sum(acvalp.*distzskIRFtemp))));
            Kbarprimep0= squeeze(sum(sum(sum(kprimep.*distzskIRFtemp))));
            Divp0      = squeeze(sum(sum(sum(divvalp.*distzskIRFtemp))));
            Cvalp      = Yp0 - Ip0 - ACp0;
            Cp0        = Cvalp;
            ep0        = 1./p0'-Cp0;

        else
            % ---- normal period: conditional expectation by firm-level state ----
            EVMAT = zeros(knum,znum,snum);
            temp  = (1.0-Kprimewgt)*squeeze(V(:,:,Kprimeind,:,:,:))...
                + Kprimewgt*squeeze(V(:,:,Kprimeind+1,:,:,:));            %(A',S',z',s',k')
            temp1 = PiSigma(Sct,:)*reshape(PiA(Act,:,Sct)*reshape(temp,[Anum,Snum*znum*snum*knum]),...
                [Snum,znum*snum*knum]);  %(1,S',z',s',k') to (S',z'*s'*k') to (z',s',k')
            for zct = 1:znum
                for sct = 1:snum
                    EVMAT(:,zct,sct) = Pisigma(sct,:)*reshape(Piz(zct,:,sct)*reshape(temp1,[znum,snum*knum]),[snum,knum]);
                    %(z',s',k') to (z',s'*k') to (1,s',k') to (k',1)
                end
            end

            disttemp = repmat(distzskIRF(:,:,:,t),[1 1 1 pnum]);          %(z,s,k,p)
            EVMAT_H  = repmat(EVMAT,[1 1 1 knum pnum]);                   %(k',z,s,k,p)
            RHSMAT2  = Rmatp + beta*EVMAT_H;
            [~,kprimeinddum] = max(RHSMAT2,[],1);
            kprimeindp = squeeze(kprimeinddum);
            kprimep    = k0(kprimeindp);                                  %(z,s,k,p)

            yvalp     = squeeze(yval(1,:,:,:,:));
            zvalp     = squeeze(z1(1,:,:,:,:));
            kzvalp    = squeeze(kz(1,:,:,:,:));
            ivalptemp = (kprimep - (1-delta)*squeeze(k1(1,:,:,:,:)));
            ivalptemp(abs(ivalptemp)<=ZeroBoundary)=0;
            ivalp     = ivalptemp;
            acvalp    = c_k*(ivalp~=0).*yvalp;
            divvalp   = (yvalp-ivalp-acvalp) ;

            % aggregates, one value per price grid point
            Yp0        = squeeze(sum(sum(sum(yvalp.*disttemp))));
            KZp0       = squeeze(sum(sum(sum(kzvalp.*disttemp))));
            Ip0        = squeeze(sum(sum(sum(ivalp.*disttemp))));
            ACp0       = squeeze(sum(sum(sum(acvalp.*disttemp))));
            Kbarprimep0= squeeze(sum(sum(sum(kprimep.*disttemp))));
            Divp0      = squeeze(sum(sum(sum(divvalp.*disttemp))));
            polmatp_interp = kprimeindp;                                  %(z,s,k,p)
            Cvalp = Yp0-Ip0-ACp0;
            Cp0   = Cvalp;
            ep0   = 1./p0'-Cp0;
        end

        % ---- market-clearing price: find p with 1/p = C(p) ----
        % bracket set-up
        pvala = plb;
        pvalb = pub;
        pvalc = pvala + (0.67)*(pvalb-pvala);

        % Each iteration interpolates the aggregate consumption schedule Cp0
        % at a trial price and updates the bracket. If maxpit is reached
        % without meeting perrortol, the last trial price is used.
        for piter=1:maxpit
            % Brent set-up: the first three trials are the bracket points
            if (piter==1); pval = pvala;end
            if (piter==2); pval = pvalb;end
            if (piter==3); pval = pvalc;end

            % Brent restart around this period's price forecast
            if (piter==pcutoff)
                pvala = pfcstval-4.0*pwindow;
                pval = pvala;
            elseif (piter==pcutoff+1)
                pvalb = pfcstval+4.0*pwindow;
                pval = pvalb;
            elseif (piter==pcutoff+2)
                pvalc = pvala + (0.67) * (pvalb-pvala);
                pval = pvalc;
            end

            % if not restarting or initializing
            if ((piter>3 && piter<pcutoff) || (piter>pcutoff+2))
                % first, try inverse quadratic interpolation of the excess demand function
                pval = ( pvala * fb * fc ) / ( (fa - fb) * (fa - fc) ) ...
                    + ( pvalb * fa * fc ) / ( (fb - fa) * (fb - fc ) ) ...
                    + ( pvalc * fa * fb ) / ( (fc - fa) * (fc - fb ) );
                % fall back to bisection if the step leaves the bracket or
                % lands too close to either end of it
                if (min(abs(pvala - pval),abs(pvalb-pval))<abs((pvalb-pvala)/9)) || (pval<pvala) || (pval>pvalb)
                    pval = (pvala + pvalb) /2;
                end
            end

            % clamp the trial price to the price grid and locate it
            pval = min(max(p0(1),pval),p0(pnum));
            pind = sum(pval>=p0);

            % guard against off grid point values
            if (pind==pnum)
                pind = pnum-1;
                pwgt = 1.0;
            elseif (pind==0)
                pind = 1;
                pwgt = 0.0;
            else
                pwgt = (pval - p0(pind)) / (p0(pind + 1) - p0(pind));
            end

            % linear interpolation of the consumption schedule, floored at kmin
            CvalpNew = Cp0(pind)*(1-pwgt) + Cp0(pind+1)*pwgt;
            CvalpNew = max(CvalpNew,kmin);

            % are you initializing the brent?
            if (piter==1); fa = (1/pval) - CvalpNew;end
            if (piter==2); fb = (1/pval) - CvalpNew;end
            if (piter==3); fc = (1/pval) - CvalpNew;end

            % are you restarting the brent?
            if (piter==pcutoff);   fa = (1/pval) - CvalpNew; end
            if (piter==pcutoff+1); fb = (1/pval) - CvalpNew; end
            if (piter==pcutoff+2); fc = (1/pval) - CvalpNew; end

            % excess demand at this trial price
            perror = (1/pval) - CvalpNew;

            % if not restarting or initializing, shrink the bracket
            if ((piter>3 && piter<pcutoff)||(piter>pcutoff+2))
                if (perror<0)
                    pvalc = pvalb; fc = fb;
                    pvalb = pval; fb = perror;
                    % pvala doesn't change
                elseif (perror>=0)
                    pvalc = pvala; fc = fa;
                    pvala = pval; fa = perror;
                    % pvalb doesn't change
                end
            end

            % exit criterion for market-clearing: log(p*C) close to zero
            perror = log(pval*CvalpNew);
            if (abs(perror)<perrortol && piter>2)
                break
            end
        end  % piter

        % insert market-clearing price and other linearly interpolated stuff into sim series
        perrsimIRF(t,simct) = perror;
        psimIRF(t,simct)    = pval;         % most recently tried p from the clearing algorithm
        CsimIRF(t,simct)    = CvalpNew;     % already the linearly interpolated consumption C
        YsimIRF(t,simct)    = Yp0(pind)*(1.0-pwgt)   + Yp0(pind+1)*pwgt;          % output Y
        IsimIRF(t,simct)    = Ip0(pind)*(1.0-pwgt)   + Ip0(pind+1)*pwgt;          % investment I
        DivsimIRF(t,simct)  = Divp0(pind)*(1.0-pwgt) + Divp0(pind+1)*pwgt;        % dividends Div
        ACsimIRF(t,simct)   = ACp0(pind)*(1.0-pwgt)  + ACp0(pind+1)*pwgt;         % adjustment costs AC
        KsimIRF(t+1,simct)  = Kbarprimep0(pind)*(1.0-pwgt) + Kbarprimep0(pind+1)*pwgt; % next-period K'
        KZsimIRF(t,simct)   = KZp0(pind)*(1.0-pwgt) + KZp0(pind+1)*pwgt;          % aggregate z*k

        % output price stats on certain periods
        if (mod(t,shockperIRF-10)==1)
            disp(['IRF = ',num2str(simct), ' t =  ',num2str(t),' iter  ',num2str(piter),' p ',num2str(pval),' err ',num2str(perrsimIRF(t,simct)), ' A ',num2str(A0(Act)),' Kprime ' ,num2str(KsimIRF(t+1,simct))])
        end

        % If K' hits the boundary, reset K' and this period's distribution to
        % the unconditional-simulation starting point. The reset distribution
        % is what gets pushed forward below in a normal period; in the shock
        % period distzskIRF_new is pushed forward instead.
        if  KsimIRF(t+1,simct)<=kmin ||  KsimIRF(t+1,simct)>=kmax
            KsimIRF(t+1,simct)    = Ksim(numper);
            distzskIRF(:,:,:,t)   = distzsk(:,:,:,numper);
        end

        % ---- push the distribution forward using the policies at pind and
        %      pind+1, weighted by the price interpolation weight ----
        if t==shockperIRF
            sct_h = 2;
            % all mass sits in sct_h, so every firm transitions with sct_h's matrices
            for zct=1:znum
                transzs = repmat(Piz(zct,:,sct_h)',1,snum).*repmat(Pisigma(sct_h,:),znum,1);   %(z',s')
                for sct =1:snum
                    for kct=1:knum
                        if (distzskIRF_new(zct,sct,kct)>disttol)
                            % policy at pind
                            polstar = polmatp_interp(zct,sct,kct,pind);
                            distzskIRF(:,:,polstar,t+1) = distzskIRF(:,:,polstar,t+1) + ...
                                transzs*distzskIRF_new(zct,sct,kct) * (1.0-pwgt);

                            % policy at pind+1
                            polstar = polmatp_interp(zct,sct,kct,pind+1);
                            distzskIRF(:,:,polstar,t+1) = distzskIRF(:,:,polstar,t+1) + ...
                                transzs*distzskIRF_new(zct,sct,kct) * pwgt;
                        end
                    end  % kct
                end  % sct
            end  % zct
        else
            for zct=1:znum
                for sct =1:snum
                    transzs = repmat(Piz(zct,:,sct)',1,snum).*repmat(Pisigma(sct,:),znum,1);   %(z',s')
                    for kct=1:knum
                        if (distzskIRF(zct,sct,kct,t)>disttol)
                            % policy at pind
                            polstar = polmatp_interp(zct,sct,kct,pind);
                            distzskIRF(:,:,polstar,t+1) = distzskIRF(:,:,polstar,t+1) + ...
                                transzs*distzskIRF(zct,sct,kct,t) * (1.0-pwgt);

                            % policy at pind+1
                            polstar = polmatp_interp(zct,sct,kct,pind+1);
                            distzskIRF(:,:,polstar,t+1) = distzskIRF(:,:,polstar,t+1) + ...
                                transzs*distzskIRF(zct,sct,kct,t) * pwgt;
                        end
                    end  % kct
                end  % sct
            end  % zct
        end

        % renormalise so next period's distribution has unit mass
        distzskIRF(:,:,:,t+1) = distzskIRF(:,:,:,t+1)./sum(sum(sum(distzskIRF(:,:,:,t+1))));

    end % t
end % simct

% Ratio of aggregated z*k to aggregate capital K, per period and economy.
% KZsimIRF is only written for t = 1..Ttot-1, so the last row is 0.
TFPsimIRF = KZsimIRF./KsimIRF;

% Drop large intermediate arrays before saving the workspace.
% NOTE(review): several of these names are not created by this script;
% presumably they arrive via the loaded .mat file - confirm.
clear ACval acvalp Cvalp divval divvalp EVMAT EVMAT_H EVMAT_H3 ival ivalp ivalptemp ivaltemp k1 k1_ly
clear kprime1 kprime_ind Kprimeeest kprimeinddum kprimep nval nvalp p0mat p0z_l p0z_y p1 yval yvalp

save IRF_Real_Only.mat

